# evaluate performance metrics

# Load data, then fit idealstan and alternative models (dynamic PCA, FAMD) to compare performance metrics


library(tidyverse)
library(modelr)
library(idealstan)
library(ggplot2)
library(mirt)
library(freqdom)
library(FactoMineR)

# Load CoronaNet data and split it into time-period folds

# - select 10 countries
# - split the time points into 10 consecutive periods (folds)
# - each fold contains all 10 countries
# - fit every model to each fold and evaluate fit on that same fold
#   (an in-sample comparison per time block, not held-out cross-validation)

# only use social distancing index

# Social-distancing index data (long format), restricted to the ten
# comparison countries used throughout this script.
sd_data <- filter(
  readRDS("coronanet/index_long_model_sd.rds"),
  country %in% c(
    "United States of America", "United Arab Emirates", "Sweden",
    "Zambia", "South Africa", "Panama", "Portugal", "New Zealand",
    "China", "India"
  )
)

# Items that make up the social-distancing index. Each item is listed once
# ("other_transport" was previously duplicated).
# - "social_distance" and "number_mass" are the identification anchors passed
#   to id_estimate() (restrict_ind_high / restrict_ind_low).
# - "allow_ann_event" and "postpone_rec_event" are filtered out again just
#   before id_make() in the fold loop.
sd_items <- c("allow_ann_event", "buses", "cancel_annual_event",
              "curfew_length", "distance_other",
              "event_no_audience", "int_restrict_all", "int_restrict_border",
              "int_restrict_buses", "int_restrict_cruises",
              "int_restrict_ferries", "int_restrict_flights",
              "int_restrict_NA", "int_restrict_ports",
              "int_restrict_trains", "number_mass", "other_transport",
              "postpone_ann_event", "postpone_rec_event", "prison_pop",
              "social_distance", "subways",
              "ox_mass_gathering", "ox_public_transport", "ox_pub_events",
              "ox_stay_home", "ox_internal")

# Selects the "sd" branch of the model_id coding in the fold loop
# ("ox" items get model_id 5 when model_type is "sd").
model_type <- "sd"

# reduced DB size to 189k rows

# Split the series into consecutive time-period folds. The first 500 distinct
# dates (in chronological order) are cut into 10 blocks of 50 days each.
# If the data holds fewer than 500 dates, the trailing folds are shorter,
# and folds left with no dates at all are skipped.

n_folds <- 10L
fold_days <- 50L
days_split <- rep(seq_len(n_folds), each = fold_days)

# Uncentered R^2 in percent: (1 - SSE / sum(observed^2)) * 100.
# The same formula is applied to idealstan and dPCA so the numbers are
# comparable.
uncentered_r2 <- function(observed, fitted) {
  (1 - sum((fitted - observed)^2) / sum(observed^2)) * 100
}

# Dedupe and sort the dates first, then truncate. Taking the first 500 *rows*
# before unique() yields far fewer than 500 dates (in arbitrary row order),
# and a length-500 logical index over them produces NA fold dates.
all_dates <- head(sort(unique(sd_data$date_policy)), length(days_split))
date_fold <- days_split[seq_along(all_dates)]

# For each fold: prepare the long data, fit idealstan, then compute the fit of
# idealstan, dynamic PCA (freqdom::dpca) and FAMD (FactoMineR) on the same
# fold's data. The result has one row per model and fold with columns
#   r_sq  - fit statistic in percent,
#   model - model label,
#   split - "<first date> to <last date>" of the fold.
time_series_fold <- lapply(sort(unique(date_fold)), function(fold) {

  this_time <- all_dates[date_fold == fold]
  this_fold <- filter(sd_data, date_policy %in% this_time)

  # Code item types and fill/scale outcomes.
  # model_id: 5 = "ox" items when model_type is "sd", 3 = other "ox" items,
  # 9 = "ox_health_invest" and all non-"ox" items (continuous var_cont taken
  # from pop_out).
  to_make <- this_fold %>%
    filter(item %in% sd_items) %>%  # sd_items is a plain character vector
    group_by(item) %>%
    mutate(
      model_id = case_when(
        item == "ox_health_invest" ~ 9,
        model_type == "sd" & grepl(x = item, pattern = "ox") ~ 5,
        grepl(x = item, pattern = "ox") ~ 3,
        TRUE ~ 9
      ),
      var_cont = ifelse(model_id > 5, pop_out, 0)
    ) %>%
    mutate(
      # shift model-3/5 items whose minimum is 0 up by one
      var = ifelse(model_id %in% c(3, 5) & min(var, na.rm = TRUE) == 0,
                   var + 1, var),
      min_item = ifelse(model_id == 9,
                        min(var_cont, na.rm = TRUE),
                        min(var, na.rm = TRUE))
    ) %>%
    ungroup() %>%
    mutate(ra_num = as.numeric(scale(ra_num))) %>%
    group_by(item) %>%
    mutate(
      var = ifelse(is.na(var) & !grepl(x = item, pattern = "ox"),
                   min(var, na.rm = TRUE), var),
      var_cont = ifelse(is.na(var_cont) & item != "ox_health_invest",
                        min(var_cont, na.rm = TRUE), var_cont),
      var_cont = as.numeric(scale(var_cont))
    ) %>%
    group_by(country, item, date_policy) %>%
    mutate(n_dup = n()) %>%
    ungroup()

  to_ideal <- to_make %>%
    distinct() %>%
    mutate(
      var = as.integer(var),
      var = ifelse(model_id == 9, 0, var - 1),
      # scale() of a constant column gives NaN; treat it as 0
      var_cont = ifelse(is.nan(var_cont), 0, var_cont),
      var_cont = ifelse(is.infinite(var_cont), 0, var_cont)
    ) %>%
    filter(
      # exclusion list carried over from the full-sample script; none of
      # these countries is in the ten selected for sd_data
      !(country %in% c("Samoa", "Solomon Islands", "Saint Kitts and Nevis",
                       "Liechtenstein", "Montenegro", "Northern Cyprus",
                       "North Macedonia", "Nauru", "Equatorial Guinea",
                       "Luxembourg", "Malta", "North Korea")),
      date_policy < as.Date("2021-05-02"),
      !(item %in% c("allow_ann_event", "postpone_rec_event"))
    ) %>%
    distinct() %>%
    id_make(
      outcome_disc = "var",
      outcome_cont = "var_cont",
      person_id = "country",
      item_id = "item",
      time_id = "date_policy"
    )

  # estimate idealstan model with this dataset

  activity_fit <- to_ideal %>%
    id_estimate(vary_ideal_pts = "random_walk",
                ncores = 6,
                nchains = 2, niters = 350,
                save_warmup = TRUE,
                warmup = 600,
                fixtype = "prefix", pos_discrim = FALSE,
                restrict_ind_high = "social_distance",
                restrict_ind_low = "number_mass",
                restrict_sd_low = 3, restrict_var = TRUE,
                time_fix_sd = 0.01,
                map_over_id = "persons",
                person_sd = 1,
                discrim_reg_sd = 5,
                diff_reg_sd = 5,
                # adapt_delta = 0.95,
                het_var = FALSE,
                fix_high = 5,
                fix_low = 0,
                time_center_cutoff = 650,
                time_var = 100,
                restrict_sd_high = .001,
                id_refresh = 100,
                const_type = "items")

  # Posterior predictions. NOTE(review): each element is treated as a
  # draws x observations matrix carrying the observed outcomes in
  # attr(., "data")$outcome -- confirm against idealstan::id_post_pred docs.
  get_pred <- id_post_pred(activity_fit)

  # Pool all items: observed outcomes vs posterior-mean predictions.
  # unlist(lapply()) instead of c(sapply()): sapply returns a list when the
  # per-item lengths differ, which would break the arithmetic below.
  ag_X <- unlist(lapply(get_pred, function(pred) attr(pred, "data")$outcome),
                 use.names = FALSE)
  ag_Xhat <- unlist(lapply(get_pred, colMeans), use.names = FALSE)

  prop_var_idealstan <- uncentered_r2(ag_X, ag_Xhat)

  # Wide person-time x item table of the same outcomes idealstan saw
  # (discrete outcome for model_id 5, continuous otherwise).
  wide_data <- activity_fit@score_data@score_matrix %>%
    mutate(outcome = ifelse(model_id == 5, outcome_disc, outcome_cont)) %>%
    select(outcome, item_id, time_id, person_id) %>%
    spread(key = "item_id", value = "outcome")

  # dynamic PCA with one component.
  # NOTE(review): dpca treats row order as time; rows are stacked person-time
  # in the order spread() returns them -- confirm this is intended.
  prepare_data <- wide_data %>%
    select(-time_id, -person_id) %>%
    as.matrix()

  dpca_fit <- dpca(X = prepare_data, Ndpc = 1)

  var_dpca <- uncentered_r2(prepare_data, dpca_fit$Xhat)

  # Mixed factor model (FAMD); "ox" columns enter as categorical variables.
  famd_data <- wide_data %>%
    select(-time_id, -person_id) %>%
    mutate(across(matches("ox"), factor))

  fit_factor <- FAMD(famd_data, ncp = 1, graph = FALSE)

  # eig is a matrix (eigenvalue, percentage of variance, cumulative
  # percentage); row 1, column 2 is the percent of variance of dimension 1.
  # The former eig[2] was linear indexing, i.e. the second eigenvalue.
  var_famd <- fit_factor$eig[1, 2]

  # mirt::multipleGroup was also considered but cannot fit the continuous
  # outcome columns, so it is not part of this comparison.

  tibble(
    r_sq = c(prop_var_idealstan, var_dpca, var_famd),
    model = c("idealstan", "dpca", "FAMD"),
    split = paste0(this_time[1], " to ", last(this_time))
  )

}) %>%
  bind_rows()

# Plot each model's fit (r_sq, stored in percent) across the time folds and
# write the figure to model_comparison.pdf.
model_comparison_plot <- time_series_fold %>%
  mutate(r_sq = r_sq / 100) %>%  # percent -> proportion for the % axis
  ggplot(aes(y = r_sq,
             x = split)) +
  geom_line(aes(colour = model,
                linetype = model,
                group = model)) +
  scale_y_continuous(limits = c(0, 1), labels = scales::percent) +
  ggthemes::theme_tufte() +
  labs(y = "R2", x = "Fold") +
  theme(axis.text.x = element_text(angle = 90))

print(model_comparison_plot)

# Save the plot object explicitly instead of relying on last_plot().
ggsave("model_comparison.pdf", plot = model_comparison_plot)
